"""
LLM 智能体的上下文管理器。
作者: Kangrui Wang, Zihan Wang
日期: 2025-03-30
"""

from itertools import zip_longest

import torch
import numpy as np
from typing import List, Dict, Any, Optional, Tuple, Union
from dataclasses import dataclass, asdict
import re
from verl import DataProto
from verl.utils.dataset.rl_dataset import collate_fn
from transformers import AutoTokenizer
import hydra
from ragen.utils import register_resolvers
from ragen.env import REGISTERED_ENV_CONFIGS
from tensordict import TensorDict
import json
from json_repair import repair_json  # repairs malformed JSON emitted by the LLM

register_resolvers()


def get_special_tokens(tokenizer: AutoTokenizer):
    """Return the (turn-start, turn-end) special token ids for the tokenizer's model family."""
    if "qwen" in tokenizer.name_or_path.lower():
        special_token = tokenizer.encode("<|im_start|>")[0]
        reward_token = tokenizer.encode("<|im_end|>")[0]
    elif "llama-3" in tokenizer.name_or_path.lower():
        special_token = 128006  # <|start_header_id|>
        reward_token = 128009  # <|eot_id|>
    else:
        raise ValueError(f"Unsupported model: {tokenizer.name_or_path}")
    return special_token, reward_token


def get_masks_and_scores(
    input_ids: torch.Tensor,
    tokenizer: AutoTokenizer,
    all_scores: Optional[List[List[float]]] = None,
    use_turn_scores: bool = False,
    enable_response_mask: bool = False,
):
    """
    input_ids: shape (bsz, seq_len)
    Get loss mask that only learns between <|im_start|>assistant and <|im_end|>. Currently only supports qwen.
    NOTE: important! This assumes that the input_ids starts with system and then user & assistant in alternative ways
    """
    special_token, reward_token = get_special_tokens(tokenizer)

    turn_starts = torch.where(input_ids == special_token, 1, 0)
    turn_indicators = torch.cumsum(turn_starts, dim=-1)
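    # For example (token layout, not literal ids): a chat laid out as
    #   <|im_start|>system ... <|im_start|>user ... <|im_start|>assistant ...
    # yields cumulative turn indicators of 1 for the system turn, 2 for the first
    # user turn, 3 for the first assistant turn, and so on; odd indicators > 1
    # therefore select exactly the assistant turns.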
    if enable_response_mask:
        loss_mask = (turn_indicators % 2 == 1) & (
            turn_indicators > 1
        )  # only learns all assistant turns
    else:
        loss_mask = turn_indicators > 1  # learns everything after system prompt
    response_mask = (turn_indicators % 2 == 1) & (turn_indicators > 1)

    score_tensor = torch.zeros_like(input_ids, dtype=torch.float32)
    if use_turn_scores:
        for idx, scores in enumerate(zip_longest(*all_scores, fillvalue=0)):
            scores = torch.tensor(scores, dtype=torch.float32)
            turn_indicator = (
                idx * 2 + 3
            )  # 0: pad. 1: system. 2+2n: user. 3+2n: assistant
            reward_position = (input_ids == reward_token) & (
                turn_indicators == turn_indicator
            )
            # Set the last token of the rows where all positions are False to True
            reward_position[~reward_position.any(dim=-1), -1] = True
            score_tensor[reward_position] = scores
        if "qwen" in tokenizer.name_or_path.lower():
            # for Qwen, there is a "\n" between special token and reward token, so we shift this to make sure reward is assigned to the last token of a turn
            score_tensor = score_tensor.roll(shifts=1, dims=-1)
    else:
        scores = [sum(i) for i in all_scores]
        score_tensor[:, -1] = torch.tensor(scores, dtype=torch.float32)
    score_tensor = score_tensor[:, 1:]  # remove the first token
    loss_mask = loss_mask[:, :-1]  # remove the last token
    response_mask = response_mask[:, :-1]  # remove the last token

    return score_tensor, loss_mask, response_mask


class ContextManager:
    """
    管理 LLM 与环境交互的上下文。
    负责在环境输出与 LLM 输入之间进行转换（双向）。
    """

    def __init__(
        self,
        config,
        tokenizer,
        processor=None,
        mode: str = "train",
    ):
        """
        初始化 ContextManager。
        processor 用于处理图像数据。
        """
        self.config = config
        self.tokenizer = tokenizer
        self.processor = processor
        self.action_sep = self.config.agent_proxy.action_sep
        self.special_token_list = [
            "<think>",
            "</think>",
            "<answer>",
            "</answer>",
            "<|im_start|>",
            "<|im_end|>",
        ]

        self.es_cfg = self.config.es_manager[mode]
        self.env_nums = {
            env_tag: n_group * self.es_cfg.group_size
            for n_group, env_tag in zip(
                self.es_cfg.env_configs.n_groups, self.es_cfg.env_configs.tags
            )
        }
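        # e.g. (hypothetical) tags=["envA", "envB"], n_groups=[2, 1], group_size=4
        # gives self.env_nums == {"envA": 8, "envB": 4}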
        self._init_prefix_lookup()

    def _check_env_installed(self, env_type: str):
        if env_type not in REGISTERED_ENV_CONFIGS:
            raise ValueError(
                f"Environment {env_type} is not installed. Please install it using the scripts/setup_{env_type}.sh script."
            )

    def _init_prefix_lookup(self):
        prefix_lookup = {}
        prefixes = {}
        env_config_lookup = {}
        for env_tag, env_config in self.config.custom_envs.items():
            if env_tag not in self.es_cfg.env_configs.tags:
                continue

            self._check_env_installed(env_config.env_type)
            env_config_new = asdict(REGISTERED_ENV_CONFIGS[env_config.env_type]())
            for k, v in env_config.items():
                env_config_new[k] = v
            env_instruction = env_config_new.get("env_instruction", "")

            # Optional: append the grid vocabulary and action list to the instruction
            if env_config_new.get("grid_vocab", False):
                grid_vocab_str = (
                    "\nThe meaning of each symbol in the state is:\n"
                    + ", ".join(
                        [f"{k}: {v}" for k, v in env_config_new["grid_vocab"].items()]
                    )
                )
                env_instruction += grid_vocab_str
            if env_config_new.get("action_lookup", False):
                action_lookup_str = "\nYour available actions are:\n" + ", ".join(
                    [f"{v}" for k, v in env_config_new["action_lookup"].items()]
                )
                action_lookup_str += (
                    f"\nYou can make up to {env_config_new['max_actions_per_traj']} actions, "
                    f'separated by the action separator " {self.action_sep} "\n'
                )
                env_instruction += action_lookup_str
            prefixes[env_tag] = env_instruction
            # Mid-turn and final-turn instructions
            if "mid_turn_instruction" in env_config_new:
                prefixes[f"{env_tag}_mid"] = env_config_new["mid_turn_instruction"]
            if "final_turn_instruction" in env_config_new:
                prefixes[f"{env_tag}_final"] = env_config_new["final_turn_instruction"]

            env_config_lookup[env_tag] = {
                "max_tokens": env_config.get(
                    "max_tokens", self.config.actor_rollout_ref.rollout.response_length
                )
            }

        tags = self.es_cfg.env_configs.tags
        n_groups = self.es_cfg.env_configs.n_groups
        group_size = self.es_cfg.group_size

        cur_group = 0
        for env_tag, n_group in zip(tags, n_groups):
            env_instruction = prefixes[env_tag]
            start_idx = cur_group * group_size
            end_idx = (cur_group + n_group) * group_size
            # Final-turn instruction (falls back to the base instruction)
            final_turn_instruction = prefixes.get(f"{env_tag}_final", env_instruction)
            # Mid-turn instruction (optional)
            mid_turn_instruction = prefixes.get(f"{env_tag}_mid", "")

            for i in range(start_idx, end_idx):
                prefix_lookup[i] = env_instruction
                env_config_lookup[i] = env_config_lookup[env_tag]
                # Store the final-turn and mid-turn prefixes
                prefix_lookup[f"{i}_final"] = final_turn_instruction
                if mid_turn_instruction:
                    prefix_lookup[f"{i}_mid"] = mid_turn_instruction

            cur_group += n_group

        self.prefix_lookup = prefix_lookup
        self.env_config_lookup = env_config_lookup
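        # e.g. (hypothetical) with tags=["envA"], n_groups=[1], group_size=2 this yields
        # prefix_lookup == {0: instr, 1: instr, "0_final": final_instr, "1_final": final_instr},
        # plus "{i}_mid" entries when a mid-turn instruction is configured.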

    def _parse_response(self, response: str) -> Tuple[str, List[str], str, str]:
        """
        Parse an LLM response, supporting several formats:
        1. JSON-array tool calls
        2. The legacy separator-delimited format
        3. The think-answer format

        Args:
            response: raw LLM response text

        Returns:
            (llm_response, actions, think_content, answer_content)
            - llm_response: the reformatted LLM response
            - actions: the list of parsed actions
            - think_content: the reasoning content
            - answer_content: the answer content
        """
        # Initialize defaults so every return value is always defined
        think_content = ""
        answer_content = response
        actions = []
        llm_response = response
        
        # Choose the match pattern based on whether thinking is enabled
        pattern = (
            r"<think>(.*?)</think>\s*<answer>(.*?)</answer>"
            if self.config.agent_proxy.enable_think
            else r"<answer>(.*?)</answer>"
        )
        match = re.search(pattern, response, re.DOTALL)

        if not match:  # no match found
            # Keep the invalid string as-is and fall through with the
            # pre-initialized defaults; no reassignment is needed.
            pass
        else:
            # Extract the think and answer contents
            if self.config.agent_proxy.enable_think:
                think_content, answer_content = match.group(1), match.group(2)
            else:
                think_content, answer_content = "", match.group(1)

            # Strip special tokens
            for special_token in self.special_token_list:
                answer_content = answer_content.replace(special_token, "").strip()
                think_content = think_content.replace(special_token, "").strip()

            # Try to parse tool-call JSON (actions was already initialized above)
            try:
                # Strip line comments and surrounding whitespace
                cleaned_content = re.sub(r"//.*", "", answer_content)
                cleaned_content = cleaned_content.strip()

                # Handle a possibly multi-line JSON object
                if cleaned_content.startswith("{") and cleaned_content.endswith("}"):
                    try:
                        # Repair the JSON string with json_repair
                        repaired_content = repair_json(
                            cleaned_content, ensure_ascii=False
                        )

                        # Parse the repaired JSON object
                        tool_calls_data = json.loads(repaired_content)

                        # Check for the new JSON-array format
                        if "tool_calls" in tool_calls_data:
                            # New format: a JSON array of tool calls
                            tool_calls = tool_calls_data["tool_calls"]
                            if tool_calls is not None:
                                for tool_call in tool_calls:
                                    if (
                                        isinstance(tool_call, dict)
                                        and "tool_call" in tool_call
                                    ):
                                        actions.append(
                                            json.dumps(tool_call, ensure_ascii=False)
                                        )
                                    elif isinstance(tool_call, dict):
                                        # Bare tool call: wrap it in the expected format
                                        actions.append(
                                            json.dumps(
                                                {"tool_call": tool_call},
                                                ensure_ascii=False,
                                            )
                                        )
                        elif "tool_call" in tool_calls_data:
                            # A single tool call
                            actions.append(
                                json.dumps(tool_calls_data, ensure_ascii=False)
                            )
                        elif "action" in tool_calls_data or "args" in tool_calls_data:
                            # Legacy format: action and args at the top level
                            actions.append(
                                json.dumps(
                                    {"tool_call": tool_calls_data}, ensure_ascii=False
                                )
                            )
                    except Exception as e:
                        print(f"JSON修复和解析失败: {e}")
                        print(f"尝试解析的内容: \n{cleaned_content}\n")

                # Try to extract multiple JSON objects from the text
                if not actions:
                    # Find every candidate JSON object
                    json_pattern = r"\{(?:[^{}]|(?:\{(?:[^{}]|(?:\{[^{}]*\}))*\}))*\}"
                    json_matches = re.findall(json_pattern, cleaned_content)

                    for json_str in json_matches:
                        try:
                            # Repair the JSON string with json_repair
                            repaired_json = repair_json(json_str, ensure_ascii=False)

                            # Parse the repaired JSON object
                            json_obj = json.loads(repaired_json)
                            if isinstance(json_obj, dict) and (
                                "tool_call" in json_obj or "action" in json_obj
                            ):
                                actions.append(repaired_json)
                        except Exception:
                            continue

                # Fall back to the action-separator format
                if not actions and self.action_sep in cleaned_content:
                    action_parts = cleaned_content.split(self.action_sep)
                    for part in action_parts:
                        part = part.strip()
                        if part:
                            try:
                                # Check whether this part looks like JSON
                                if part.startswith("{") and part.endswith("}"):
                                    # Try to repair and parse it as a JSON object
                                    repaired_part = repair_json(
                                        part, ensure_ascii=False
                                    )
                                    json_obj = json.loads(repaired_part)
                                    # If it has a valid tool-call shape, use it directly
                                    if isinstance(json_obj, dict) and (
                                        "tool_call" in json_obj
                                        or "action" in json_obj
                                        or "name" in json_obj
                                    ):
                                        if "tool_call" in json_obj:
                                            actions.append(
                                                json.dumps(json_obj, ensure_ascii=False)
                                            )
                                        else:
                                            actions.append(
                                                json.dumps(
                                                    {"tool_call": json_obj},
                                                    ensure_ascii=False,
                                                )
                                            )
                                        continue
                            except Exception:
                                pass  # not valid JSON; fall through to the default wrapping

                            # Wrap the plain-text action in the tool-call format
                            action_obj = {
                                "tool_call": {
                                    "name": "execute_action",
                                    "args": {"action": part},
                                }
                            }
                            actions.append(json.dumps(action_obj, ensure_ascii=False))

            except Exception as e:
                print(f"解析工具调用失败: {e}")
                print(f"原始响应: {response[:100]}...")

            # Cap the number of actions
            max_actions = self.config.agent_proxy.max_actions_per_turn
            if len(actions) > max_actions:
                print(
                    f"Warning: action count {len(actions)} exceeds the limit {max_actions}; truncating the extras"
                )
                actions = actions[:max_actions]  # keep only the first max_actions actions

            # Reformat the LLM response
            llm_response = (
                f"<think>{think_content}</think><answer>{answer_content}</answer>"
                if self.config.agent_proxy.enable_think
                else f"<answer>{answer_content}</answer>"
            )

        return llm_response, actions, think_content, answer_content

    def _normalize_score_tensor(
        self, score_tensor: torch.Tensor, env_outputs: List[Dict]
    ) -> torch.Tensor:
        """
        Normalize the score tensor to be between 0 and 1.
        NOTE: only support score at the last token for now
        """
        assert (
            self.config.agent_proxy.use_turn_scores == False
        ), "Reward normalization is not supported for use_turn_scores == True"

        rn_cfg = self.config.agent_proxy.reward_normalization
        grouping, method = rn_cfg.grouping, rn_cfg.method
        if grouping == "state":
            group_tags = [env_output["group_id"] for env_output in env_outputs]
        elif grouping == "inductive":
            group_tags = [env_output["tag"] for env_output in env_outputs]
        elif grouping == "batch":
            group_tags = [1] * len(env_outputs)
        else:
            raise ValueError(f"Invalid grouping: {grouping}")

        if method == "mean_std":
            norm_func = lambda x: (
                (x - x.mean(dim=-1, keepdim=True))
                / (x.std(dim=-1, keepdim=True) + 1e-6)
                if x.std(dim=-1, keepdim=True).abs().max() > 1e-6
                else torch.zeros_like(x)
            )  # guarding on std is more stable under bf16 than dividing by x.std() directly
        elif method == "mean":
            norm_func = lambda x: (x - x.mean(dim=-1, keepdim=True))
        elif method == "asym_clip":
            norm_func = lambda x: (
                (x - x.mean(dim=-1, keepdim=True))
                / (x.std(dim=-1, keepdim=True) + 1e-6)
                if x.std(dim=-1, keepdim=True).abs().max() > 1e-6
                else torch.zeros_like(x)
            ).clamp(min=-1, max=3)
        elif method == "identity":
            norm_func = lambda x: x
        else:
            raise ValueError(f"Invalid normalization method: {method}")

        # apply groupwise normalization
        group2index = {}
        for i, env_tag in enumerate(group_tags):
            if env_tag not in group2index:
                group2index[env_tag] = []
            group2index[env_tag].append(i)
        group2index = {k: torch.tensor(v) for k, v in group2index.items()}

        # apply penalty pre-normalization
        acc_scores = score_tensor[:, -1]
        normalized_acc_scores = acc_scores.clone()
        penalty = torch.tensor(
            [env_output.get("penalty", 0) for env_output in env_outputs],
            dtype=torch.float32,
        )
        normalized_acc_scores = normalized_acc_scores + penalty

        if len(group2index) < acc_scores.shape[0]:  # the group size > 1
            for group, index in group2index.items():
                normalized_acc_scores[index] = norm_func(normalized_acc_scores[index])

        score_tensor[:, -1] = normalized_acc_scores

        return score_tensor

    def get_lm_inputs(
        self,
        env_outputs: List[Dict],
        prepare_for_update: bool,
        current_turn: Optional[int] = None,
    ) -> DataProto:
        """
        env_outputs - 见下方示例
        [
            {"env_id": 1, "history": [{"state": "###\n#x_#", "llm_response": "Response 1", "reward": 0.5}, {"state": "###\n#x_#"}]},
            {"env_id": 2, "history": [{"state": "###\n#x_#"}]},
            ...
        ]
        prefix_lookup - 从 env_id 到初始提示的映射
        """
        llm_input_texts = []
        messages_list = []  # for api calling
        for env_output in env_outputs:
            max_k = getattr(self.config.agent_proxy, "max_context_window", None)
            if isinstance(max_k, int) and max_k > 0:
                env_output["history"] = env_output["history"][-max_k:]

            messages = [
                {"role": "system", "content": self.prefix_lookup[env_output["env_id"]]},
            ]

            for idx, content in enumerate(env_output["history"]):
                if "state" in content:
                    messages.append({"role": "user", "content": content['state']})

                if "llm_response" in content:
                    messages.append(
                        {"role": "assistant", "content": content["llm_response"]}
                    )

            # NOTE: this assertion is important for loss mask computation
            assert all(msg["role"] == "assistant" for msg in messages[2::2])

            text = self.tokenizer.apply_chat_template(
                messages, add_generation_prompt=(not prepare_for_update), tokenize=False
            )
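            # With a Qwen-style chat template, `text` looks roughly like:
            #   <|im_start|>system\n{prefix}<|im_end|>\n
            #   <|im_start|>user\n{state}<|im_end|>\n
            #   <|im_start|>assistant\n{llm_response}<|im_end|>\n ...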
            if not prepare_for_update:
                if self.config.agent_proxy.enable_think:
                    text += "<think>"  # 强制 LLM 在回答前进行思考
                else:
                    text += "<answer>"  # 强制 LLM 开始回答
            llm_input_texts.append(text)
            messages_list.append(messages)

        inputs = self.tokenizer(
            llm_input_texts,
            return_tensors="pt",
            padding=True,
            padding_side="left",
            truncation=False,
        )  # do not truncate here; handled downstream
        input_ids, attention_mask = inputs.input_ids, inputs.attention_mask
        position_ids = (attention_mask.cumsum(dim=-1) - 1).clamp(min=0)
        if prepare_for_update:
            scores = [
                [i.get("reward", 0.0) for i in env_output["history"]]
                for env_output in env_outputs
            ]
            score_tensor, loss_mask, response_mask = get_masks_and_scores(
                input_ids,
                self.tokenizer,
                scores,
                use_turn_scores=self.config.agent_proxy.use_turn_scores,
                enable_response_mask=self.config.enable_response_mask,
            )

            normalized_score_tensor = score_tensor
            if not self.config.agent_proxy.use_turn_scores:
                normalized_score_tensor = self._normalize_score_tensor(
                    score_tensor, env_outputs
                )
            response_length = response_mask.sum(dim=-1).float().mean().item()

        llm_inputs = DataProto()
        llm_inputs.batch = TensorDict(
            {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "position_ids": position_ids,
                "responses": input_ids[:, 1:],  # remove the first token
            },
            batch_size=input_ids.shape[0],
        )

        if prepare_for_update:
            # these tensors were already shifted/truncated inside
            # get_masks_and_scores, so they align with `responses` (seq_len - 1)
            llm_inputs.batch["loss_mask"] = loss_mask
            llm_inputs.batch["rm_scores"] = normalized_score_tensor
            llm_inputs.batch["original_rm_scores"] = score_tensor
        llm_inputs.non_tensor_batch = {
            "env_ids": np.array(
                [env_output["env_id"] for env_output in env_outputs], dtype=object
            ),
            "group_ids": np.array(
                [env_output["group_id"] for env_output in env_outputs], dtype=object
            ),
            "messages_list": np.array(messages_list, dtype=object),
            "env_infos": np.array(
                [env_output.get("env_info", {}) for env_output in env_outputs],
                dtype=object,
            ),
        }

        if prepare_for_update:
            metrics = {}
            for env_output in env_outputs:
                for key, value in env_output["metrics"].items():
                    if key not in metrics:
                        metrics[key] = []
                    metrics[key].append(value)
            mean_metrics = {
                key: np.sum(value) / self.env_nums[key.split("/")[0]]
                for key, value in metrics.items()
            }
            for key, values in metrics.items():
                if not isinstance(values, list):
                    continue
                prefix, suffix = key.split("/", 1)
                non_zero_values = [v for v in values if v != 0]
                if non_zero_values:  # Avoid division by zero
                    non_zero_key = f"{prefix}/non-zero/{suffix}"
                    mean_metrics[non_zero_key] = np.mean(non_zero_values)
            metrics = mean_metrics
            metrics["response_length"] = response_length
            llm_inputs.meta_info = {"metrics": metrics}
        return llm_inputs

    def get_env_inputs(self, lm_outputs: DataProto) -> List[Dict]:
        if lm_outputs.batch is not None and "responses" in lm_outputs.batch.keys():
            responses = self.tokenizer.batch_decode(
                lm_outputs.batch["responses"], skip_special_tokens=True
            )
        else:  # dataproto has textual responses
            responses = lm_outputs.non_tensor_batch["response_texts"]
        responses = [
            (
                "<think>" + response
                if self.config.agent_proxy.enable_think
                else "<answer>" + response
            )
            for response in responses
        ]  # The LLM generation does not include <think> tags. Add them back here.

        env_ids = lm_outputs.non_tensor_batch["env_ids"]
        env_inputs = []
        for env_id, response in zip(env_ids, responses):
            llm_response, actions, think_content, answer_content = self._parse_response(response)
            env_inputs.append(
                {
                    "env_id": env_id,
                    "llm_raw_response": response,
                    "llm_response": llm_response,
                    "actions": actions,
                    "think_content": think_content,
                    "answer_content": answer_content,
                }
            )
        return env_inputs

    def formulate_rollouts(self, env_outputs: List[Dict]) -> DataProto:
        llm_inputs = self.get_lm_inputs(env_outputs, prepare_for_update=True)
        return llm_inputs
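

if __name__ == "__main__":
    # Minimal smoke test for get_masks_and_scores. A sketch: it assumes the
    # "Qwen/Qwen2.5-0.5B-Instruct" tokenizer can be downloaded (or is cached)
    # and exercises a single system/user/assistant exchange.
    tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
    demo_messages = [
        {"role": "system", "content": "You are a helpful agent."},
        {"role": "user", "content": "###\n#x_#"},
        {"role": "assistant", "content": "<answer>Right</answer>"},
    ]
    demo_text = tok.apply_chat_template(demo_messages, tokenize=False)
    demo_ids = tok(demo_text, return_tensors="pt").input_ids
    demo_scores, demo_loss_mask, demo_response_mask = get_masks_and_scores(
        demo_ids,
        tok,
        all_scores=[[1.0]],
        use_turn_scores=False,
        enable_response_mask=True,
    )
    # The response mask should cover only the assistant turn.
    print("response tokens:", tok.decode(demo_ids[0, 1:][demo_response_mask[0]]))
    print("final-token score:", demo_scores[0, -1].item())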
